package com.za.plugin.transfer.form.insertupdate;

import cn.hutool.core.collection.CollectionUtil;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.za.plugin.util.SqlUtil;
import com.za.plugin.util.StrUtil;
import org.apache.ibatis.mapping.ParameterMapping;

import java.util.*;
import java.util.stream.Collectors;

/**
 * {@link StyleTransfer} for {@code UPDATE} statements.
 *
 * <p>If the statement assigns any column listed in {@code autoIncrProperties}, the SET clause is rebuilt
 * without those columns. {@code parameterMappings} is then refilled so it lines up with the {@code ?}
 * placeholders of the rewritten SQL. Statements that assign no such column are returned unchanged, together
 * with their original mappings.
 *
 * <p>NOTE(review): the string-based parsing below searches for lower-case keywords ({@code " set "},
 * {@code "where"}), so callers presumably pass lower-case SQL — confirm.
 */
public class UpdateStyleTransfer implements StyleTransfer {

    private static final String UPDATE_KEYWORD = "update";

    /** Returns {@code true} when the trimmed statement starts with the lower-case keyword {@code update}. */
    @Override
    public boolean isSupport(String sql) {
        return sql.trim().startsWith(UPDATE_KEYWORD);
    }

    /**
     * Rewrites an update statement so that auto-increment columns are not assigned.
     *
     * @param sql                   the update statement (one statement, or several separated by {@code ;})
     * @param tableName             target table, possibly schema-qualified and/or double-quoted
     * @param autoIncrProperties    columns that must be removed from the SET clause
     * @param parameterMappingsCopy the original mappings, in placeholder order
     * @param parameterMappings     output list; receives the mappings of the returned SQL
     * @param pKs                   passed through to the rewrite helper, which does not read it
     * @param pkAndUniqueKeys       not used by this implementation
     * @return the rewritten SQL, or {@code sql} itself when no rewrite is needed
     */
    @Override
    public String transfer(String sql, String tableName, List<String> autoIncrProperties,
                           List<ParameterMapping> parameterMappingsCopy,
                           List<ParameterMapping> parameterMappings, Set<String> pKs
            , Map<String, List<String>> pkAndUniqueKeys) {

        MySqlStatementParser parser = new MySqlStatementParser(sql.toLowerCase(Locale.ROOT));
        SQLStatement sqlStatement = parser.parseStatement();
        if (sqlStatement instanceof MySqlUpdateStatement) {
            List<SQLUpdateSetItem> items = ((MySqlUpdateStatement) sqlStatement).getItems();
            if (CollectionUtil.isNotEmpty(items)) {
                boolean setsAutoIncrColumn = items.stream()
                        .filter(Objects::nonNull)
                        .map(SQLUpdateSetItem::getColumn)
                        .map(SQLUtils::toMySqlString)
                        .anyMatch(column -> isContains(column, autoIncrProperties));
                if (!setsAutoIncrColumn) {
                    // Nothing to strip: keep the SQL and its placeholders exactly as they are.
                    parameterMappings.addAll(parameterMappingsCopy);
                    return sql;
                }
            }
        }

        // Trimmed for consistency with isSupport(); otherwise leading whitespace would skip the rewrite
        // and leave parameterMappings empty.
        if (!sql.trim().startsWith(UPDATE_KEYWORD)) {
            return sql;
        }

        Map<String, ParameterMapping> map = SqlUtil.getPropertyMapFromParameterMappings(parameterMappingsCopy);
        // Mappings keyed by the SQL token that precedes each placeholder.
        Map<String, ParameterMapping> sqlMap = initSqlMap(sql, parameterMappingsCopy);

        // Whatever sits between the table name and SET: a table alias, a join
        // ("update tbl [alias] [left join ...] set"), or nothing at all.
        String aliasOrJoin = getAliasOrJoin(sql, tableName);
        String bareTableName = SqlUtil.getTableName(tableName);

        // The SQL may hold several statements, e.g. "update xxx; update xxx; ...".
        int statementCount = StrUtil.getCount(sql, "update ");
        if (statementCount > 1) {
            return transferMultiple(sql, statementCount, bareTableName, aliasOrJoin, map, sqlMap,
                    autoIncrProperties, parameterMappingsCopy, parameterMappings, pKs);
        }
        return "UPDATE " + bareTableName + " " + aliasOrJoin + " set " +
                getUpdateSql(sql, map, parameterMappings, pKs, sqlMap, parameterMappingsCopy,
                        aliasOrJoin, autoIncrProperties);
    }

    /**
     * Rewrites a ';'-separated batch of update statements one by one.
     *
     * <p>Assumes every statement targets the same table/alias and owns the same number of placeholders, so
     * statement k uses the k-th equal-sized slice of {@code parameterMappingsCopy}.
     */
    private String transferMultiple(String sql, int statementCount, String bareTableName, String aliasOrJoin,
                                    Map<String, ParameterMapping> map, Map<String, ParameterMapping> sqlMap,
                                    List<String> autoIncrProperties, List<ParameterMapping> parameterMappingsCopy,
                                    List<ParameterMapping> parameterMappings, Set<String> pKs) {
        int total = parameterMappingsCopy.size();
        int sliceSize = total / statementCount;
        StringJoiner joined = new StringJoiner(";");
        int start = 0;
        for (String statement : sql.split(";")) {
            if (statement.trim().isEmpty()) {
                // e.g. whitespace after a trailing ';' — not a statement.
                continue;
            }
            int end = Math.min(start + sliceSize, total);
            List<ParameterMapping> slice = parameterMappingsCopy.subList(start, end);
            start = end;
            initSqlMap2(sqlMap, slice);
            // The positional fallback inside getUpdateSql counts from 0, so it must see only this slice.
            joined.add("UPDATE " + bareTableName + " " + aliasOrJoin + " set " +
                    getUpdateSql(statement, map, parameterMappings, pKs, sqlMap, slice,
                            aliasOrJoin, autoIncrProperties));
        }
        return joined.toString();
    }

    /**
     * Returns whether {@code column} (spaces removed) equals an entry of {@code set}, or ends with
     * {@code "." + entry} (an alias- or table-qualified column such as {@code t.id}).
     */
    private static boolean isContains(String column, List<String> set) {
        String normalized = column.replace(" ", "");
        return set.contains(normalized)
                || set.stream().anyMatch(name -> normalized.endsWith("." + name));
    }

    /**
     * Overwrites existing entries of {@code map} with the mappings in {@code parameterMappingsCopy}.
     *
     * <p>For a dotted property only the part after the last '.' is compared. That name is tried as-is and in
     * its {@code StrUtil.toSqlProperty}/{@code toJavaProperty} forms. Keys not already present are never added.
     *
     * @return {@code map} itself, after modification
     */
    private Map<String, ParameterMapping> initSqlMap2(Map<String, ParameterMapping> map,
                                                      List<ParameterMapping> parameterMappingsCopy) {
        for (ParameterMapping parameterMapping : parameterMappingsCopy) {
            String property = parameterMapping.getProperty();
            // +1 skips the '.' itself; lastIndexOf == -1 keeps the whole property.
            String prop = property.substring(property.lastIndexOf('.') + 1);
            map.replace(prop, parameterMapping);
            map.replace(StrUtil.toSqlProperty(prop), parameterMapping);
            map.replace(StrUtil.toJavaProperty(prop), parameterMapping);
        }
        return map;
    }

    /**
     * Groups property names by their prefix before the first '.' ({@code ""} when unqualified).
     * Currently not called from this class.
     */
    private Map<String, List<String>> isUpdateMultiTbl(List<String> propertyList) {
        Map<String, List<String>> map = new HashMap<>();
        for (String property : propertyList) {
            String prefix = property.contains(".") ? property.split("\\.")[0] : "";
            map.computeIfAbsent(prefix, key -> new ArrayList<>()).add(property);
        }
        return map;
    }

    /**
     * Returns the lower-cased, trimmed text between the table name and {@code set} (or {@code where}, or the
     * end of the statement): an alias, a join clause, or {@code ""}. Also returns {@code ""} when the table
     * name cannot be located.
     */
    private String getAliasOrJoin(String sql, String tableName) {
        String lowerSql = sql.toLowerCase(Locale.ROOT).trim();
        // Druid's parser could supply the alias directly, but it is not used here: a column named "desc" only
        // parses once back-quoted, and the back-quoted form is still rejected by Dameng (DM).
        int setIdx = lowerSql.indexOf(" set ");
        if (setIdx == -1) {
            setIdx = lowerSql.indexOf(" where ");
            if (setIdx == -1) {
                setIdx = lowerSql.length() - 1;
            }
        }
        String name = tableName.replace("\"", "").trim().toLowerCase(Locale.ROOT);
        // Search after the leading keyword so a table such as "date" does not match inside "update".
        int from = lowerSql.startsWith(UPDATE_KEYWORD) ? UPDATE_KEYWORD.length() : 0;
        int tblNameIdx = lowerSql.indexOf(name, from);
        if (tblNameIdx == -1) {
            return "";
        }
        int aliasStart = tblNameIdx + name.length();
        if (aliasStart > setIdx + 1) {
            return "";
        }
        return lowerSql.substring(aliasStart, setIdx + 1).trim();
    }

    /**
     * Maps the token preceding each {@code ?} (see {@link #searchSqlProperty}) to the mapping at the same
     * placeholder position in {@code parameterMappingsCopy}.
     *
     * <p>Known limitation: for SQL such as
     * {@code update frp_produce.tbl_sell_record_animal set sell_group_id =case when id=? then ? when id=? then ? end where id in ( ? , ? )}
     * the preceding tokens repeat ("id", "then", ...), so later placeholders overwrite earlier keys.
     */
    private Map<String, ParameterMapping> initSqlMap(String sql, List<ParameterMapping> parameterMappingsCopy) {
        Map<String, ParameterMapping> sqlMap = new LinkedHashMap<>();
        int next = 0;
        for (int i = 0; i < sql.length(); i++) {
            if (sql.charAt(i) == '?') {
                sqlMap.put(searchSqlProperty(sql, i).trim(), parameterMappingsCopy.get(next++));
            }
        }
        return sqlMap;
    }

    /**
     * Walks backwards from the placeholder at {@code start}, skipping '=' and spaces. Collects characters
     * until a character rejected by {@code StrUtil.isLetter} is met after something was collected, and
     * returns them in original order. Returns {@code ""} if the start of the string is reached first.
     */
    private String searchSqlProperty(String sql, int start) {
        int j = start - 1;
        StringBuilder sb = new StringBuilder();
        while (j >= 0) {
            if (sql.charAt(j) == '=') {
                j--;
                continue;
            }
            if (!StrUtil.isLetter(sql.charAt(j)) && sb.length() > 0) {
                return sb.reverse().toString();
            }
            if (sql.charAt(j) != ' ') {
                sb.append(sql.charAt(j));
            }
            j--;
        }
        return "";
    }

    /**
     * Rebuilds the SET list of one update statement and appends the statement's mappings, in placeholder
     * order, to {@code parameterMappings}.
     *
     * <p>Each SET/WHERE entry is resolved first via {@code sqlMap} (SQL token), then via {@code map} (Java
     * property name). Failing both, it falls back to taking the next mappings positionally from
     * {@code parameterMappingsCopy}, one per {@code ?} in the value.
     *
     * @param sql                   a single update statement
     * @param map                   mappings keyed by Java property name
     * @param parameterMappings     output list of mappings for the rewritten statement
     * @param pKs                   not read here
     * @param sqlMap                mappings keyed by the SQL token preceding each placeholder
     * @param parameterMappingsCopy positional fallback source, indexed from 0 for this statement
     * @param alias                 not read here
     * @param autoIncrProperties    SET columns to drop
     * @return the new SET list followed by everything from the first {@code "where"} onwards
     */
    private String getUpdateSql(String sql, Map<String, ParameterMapping> map,
                                List<ParameterMapping> parameterMappings, Set<String> pKs,
                                Map<String, ParameterMapping> sqlMap, List<ParameterMapping> parameterMappingsCopy,
                                String alias, List<String> autoIncrProperties) {
        StringBuilder sb = new StringBuilder(" ");

        Map<String, String> setStrMap = getSetStrList(sql);
        Map<String, String> functionValMap = SqlUtil.getFunctionValueBatch(sql, autoIncrProperties);
        int j = 0;

        for (Map.Entry<String, String> entry : getKeyValueMapOnUpdate(sql).entrySet()) {
            String property = entry.getKey();
            String value = entry.getValue();
            if (autoIncrProperties.contains(property)) {
                // Dropped from the SET list; skip the placeholder it occupied.
                j++;
                continue;
            }
            if (functionValMap.containsKey(property)) {
                sb.append(StrUtil.toSqlProperty(property)).append("=").append(functionValMap.get(property)).append(",");
                continue;
            }
            // Lookup also covers e.g. "update tbl_company_no_gen set qrcode_no= qrcode_no + ? where company_id=?".
            if (sqlMap.containsKey(property)) {
                parameterMappings.add(sqlMap.get(property));
                j++;
            } else if (map.containsKey(StrUtil.toJavaProperty(property))) {
                parameterMappings.add(map.get(StrUtil.toJavaProperty(property)));
                j++;
            } else {
                // Last resort, e.g. "sell_group_id =case when id=? then ? when id=? then ? end": the order of
                // the placeholders is known, but not which property each one belongs to.
                j = addPositionalMappings(StrUtil.getCount(value, "?"), parameterMappings, parameterMappingsCopy, j);
                System.err.println(property + "可能没有被设置上,属性为空！！！！");
            }
            if (!SqlUtil.KEY_WORD_SET.contains(property)) {
                sb.append(StrUtil.toSqlProperty(property));
            } else {
                // Reserved word: emit it upper-cased and double-quoted.
                sb.append("\"").append(StrUtil.toSqlProperty(property).toUpperCase(Locale.ROOT)).append("\"");
            }
            sb.append("=").append(setStrMap.get(StrUtil.toSqlProperty(property))).append(",");
        }

        int right = sql.indexOf("where");
        if (right == -1) {
            right = sql.length();
        }

        for (Map.Entry<String, String> entry : getWherePropertiesOnUpdate(sql).entrySet()) {
            String property = entry.getKey();
            String value = entry.getValue();
            if (property.contains(".")) {
                // Drop a qualifier such as the "t" in "t.id".
                property = property.split("\\.")[1].trim();
            }
            int placeholders = StrUtil.getCount(value, "?");
            if (sqlMap.containsKey(property) && placeholders == 1) {
                parameterMappings.add(sqlMap.get(property));
                j++;
            } else if (map.containsKey(StrUtil.toJavaProperty(property)) && placeholders == 1) {
                parameterMappings.add(map.get(StrUtil.toJavaProperty(property)));
                j++;
            } else {
                System.err.println("没设置到属性:" + property);
                j = addPositionalMappings(placeholders, parameterMappings, parameterMappingsCopy, j);
            }
        }
        // Drop the trailing ',' of the rebuilt SET list, then re-attach the original WHERE part.
        return sb.deleteCharAt(sb.length() - 1) + " " + sql.substring(right);
    }

    /**
     * Appends {@code count} mappings from {@code source}, starting at index {@code from}, to {@code target}.
     *
     * @return the index just past the last mapping taken
     */
    private static int addPositionalMappings(int count, List<ParameterMapping> target,
                                             List<ParameterMapping> source, int from) {
        int next = from;
        for (int remaining = count; remaining > 0; remaining--) {
            target.add(source.get(next++));
        }
        return next;
    }

    /**
     * Splits the SET clause (between {@code " set "} and {@code " where "}, or the end) into an
     * insertion-ordered map of column to value expression, collapsing runs of whitespace.
     */
    private Map<String, String> getKeyValueMapOnUpdate(String sql) {
        Map<String, String> map = new LinkedHashMap<>();
        int left = sql.indexOf(" set ");
        int right = sql.indexOf(" where ", left);
        // There may be no WHERE clause.
        if (right == -1) {
            right = sql.length();
        }
        List<String> segments = StrUtil.getSegment(sql, left + " set ".length(), right).stream()
                .map(it -> it.replaceAll("\\s+", " "))
                .filter(it -> !it.isEmpty())
                .collect(Collectors.toList());
        for (String segment : segments) {
            int idx = segment.indexOf("=");
            map.put(segment.substring(0, idx).trim(), segment.substring(idx + 1).trim());
        }
        return map;
    }

    /**
     * Like {@link #getKeyValueMapOnUpdate} but returns an unordered map. Only the first '=' separates column
     * from value, so "sell_group_id =case when id=? then ? ... end" keeps its whole CASE expression.
     * A repeated column keeps its last value instead of failing.
     */
    private Map<String, String> getSetStrList(String sql) {
        int left = sql.indexOf(" set ") + " set ".length();
        int right = sql.indexOf("where ", left);
        if (right == -1) {
            right = sql.length();
        }
        return StrUtil.getSegment(sql, left, right).stream()
                .map(it -> it.replaceAll("\\s+", " "))
                .filter(it -> !it.isEmpty())
                .collect(Collectors.toMap(
                        it -> it.split("=")[0].trim(),
                        it -> it.substring(it.indexOf("=") + 1).trim(),
                        (first, second) -> second));
    }

    /**
     * Parses the WHERE clause into an insertion-ordered map of column to value. Handles "col=value" and
     * "col in (...)" conditions joined by " and ". Returns an empty map when there is no {@code "where "}.
     */
    private Map<String, String> getWherePropertiesOnUpdate(String sql) {
        int left = sql.indexOf("where ");
        if (left == -1) {
            return Collections.emptyMap();
        }
        // Normalize "id = ?" to "id=?".
        String segment = sql.substring(left + "where ".length()).trim().replaceAll("\\s*\\=\\s*", "=");

        Map<String, String> map = new LinkedHashMap<>();
        // UPDATE WHERE clauses are assumed to combine conditions with "and" only.
        for (String condition : segment.split(" and ")) {
            int eqIdx = condition.indexOf("=");
            if (eqIdx != -1) {
                map.put(condition.substring(0, eqIdx).trim(), condition.substring(eqIdx + 1).trim());
            } else {
                // e.g. "id in (?,?)". NOTE: "id not in (...)" yields the key "id not".
                int inIdx = condition.indexOf(" in");
                if (inIdx != -1) {
                    map.put(condition.substring(0, inIdx).trim(), condition.substring(inIdx + 3).trim());
                }
            }
        }
        return map;
    }
}
